
[codegen] Add max(half, half) support when enable fp16 #3811

Closed
ZQPei wants to merge 2 commits

Conversation

@ZQPei (Contributor) commented Aug 21, 2019

Fix the following error when compiling a float16 model for CUDA.

```
/tmp/tmpz_0pydlm/my_kernel.cu(9890): error: more than one instance of overloaded function "max" matches the argument list:
            function "max(int, int)"
            function "max(unsigned int, unsigned int)"
            function "max(int, unsigned int)"
            function "max(unsigned int, int)"
            function "max(long, long)"
            function "max(unsigned long, unsigned long)"
            function "max(long, unsigned long)"
            function "max(unsigned long, long)"
            function "max(long long, long long)"
            function "max(unsigned long long, unsigned long long)"
            function "max(long long, unsigned long long)"
            function "max(unsigned long long, long long)"
            function "max(float, float)"
            argument types are: (half, __half)
```
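
For reference, a minimal sketch of the kind of overload the generated kernel source needs so that `max` resolves for `half` operands. The exact body emitted by this PR may differ; `__hgt` is assumed here and requires compute capability 5.3+:

```cuda
#include <cuda_fp16.h>

// Sketch only: an explicit half overload so that max(a, b) on half operands
// no longer has to pick among the integer/float built-ins listed above.
__device__ half max(const half a, const half b) {
  // __hgt compares two half values (requires sm_53 or newer)
  return __hgt(a, b) ? a : b;
}
```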

Please check!

@cchung100m (Contributor)

Hi @ZQPei

Please check the CI error:

```
docker/bash.sh tvmai/ci-lint:v0.51 ./tests/scripts/task_lint.sh
Makefile:70: recipe for target 'cpplint' failed
make: *** [cpplint] Error 1
script returned exit code 2
```

@ZQPei (Contributor, Author) commented Aug 21, 2019

Hi @cchung100m,
All 5 checks have passed now. Please take a look.
Thanks!

@cchung100m (Contributor)

Hi @ZQPei

For your PR, besides codegen_cuda.cc, we also need to add a unit test in tests/python/unittest/test_codegen_cuda.py.

@ZQPei (Contributor, Author) commented Aug 21, 2019

How do I add a unit test? Do I need to do that?

@cchung100m (Contributor)

Hi @ZQPei

I would suggest referring to the post on the TVM forum to develop your testing script.

Hope it helps,

@ZQPei (Contributor, Author) commented Aug 22, 2019

Hi @cchung100m,
I have written a unit test for this PR, and it passes on my machine. Should I add this code to TVM?

Here is the unit test.

```python
import numpy as np
import tvm
from tvm.contrib.nvcc import have_fp16


def test_cuda_vector_max():
    num_thread = 8
    target = 'cuda'

    def check_vector_max(ctx, n, dtype):
        if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
            print("skip because cuda is not enabled..")
            return
        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("skip because gpu does not support fp16")
            return
        # Build an elementwise max kernel: C[i] = max(A[i], B[i]).
        A = tvm.placeholder((n,), name='A', dtype=dtype)
        B = tvm.placeholder((n,), name='B', dtype=dtype)
        C = tvm.compute((n,), lambda i: tvm.max(A[i], B[i]), name='C')
        s = tvm.create_schedule(C.op)
        bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
        s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
        s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
        fun = tvm.build(s, [A, B, C], "cuda", name="vector_max")

        # Compare the GPU result against numpy's elementwise maximum.
        np_a = np.random.uniform(size=n).astype(dtype)
        np_b = np.random.uniform(size=n).astype(dtype)
        np_c = np.maximum(np_a, np_b)
        a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
        b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(np_b)
        c = tvm.nd.empty((n,), C.dtype, ctx)
        fun(a, b, c)
        np.testing.assert_equal(c.asnumpy(), np_c)

    ctx = tvm.context(target, 0)
    check_vector_max(ctx, 10, "float32")
    check_vector_max(ctx, 10, "float16")
```
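
A side note on the assertion: np.testing.assert_equal is fine here even for float16, because an elementwise maximum returns one of its inputs unchanged, so no rounding tolerance is needed.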

@anijain2305 (Contributor) left a comment

LGTM

```cpp
@@ -50,6 +50,8 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
std::string CodeGenCUDA::Finish() {
  if (enable_fp16_) {
    decl_stream << "#include <cuda_fp16.h>\n";
    decl_stream << "__device__ half max(const half a, const half b)\n"
```
Contributor:
Do we know which operators we have to overload as such? "max" is one of them. Do we need others?

Contributor (Author):
For now, max is the only one I have found that needs to be overloaded.

BTW, I have a question about the checks: why does this commit fail to build today? It built successfully yesterday.

Contributor:
Hi, we saw more failures while trying to run the full ResNet:

#3816 (comment)

I think we are missing all the reduce ops. Would it be possible for you to help with this? (In a separate PR; this one is good to go.)

@ZQPei (Contributor, Author) commented Aug 22, 2019

Hi @cchung100m @anijain2305,
The commit to codegen_cuda.cc passed yesterday, but the commit adding the unit test to test_codegen_cuda.py failed today. Can you help me find out what caused the check to fail?

@ZQPei force-pushed the master branch 2 times, most recently from 1dc314e to 09ea9ab on August 22, 2019 at 09:17
Fix the error shown above when compiling a float16 model.

add max(half, half) support when enable fp16

fix cpplint error.

add max(half, half) support when enable fp16

fix cpplint error, replace tab with whitespace

add unittest for vector_max

add unittest for vector_max

add max(half, half) support when enable fp16

add max(half, half) support when enable fp16
@vinx13 (Member) commented Aug 22, 2019

@ZQPei please also add the test case to this PR

@cchung100m (Contributor)

Hi @ZQPei

Please first post in discuss.tvm.ai and provide more details of what you're doing. Currently, it is not clear where the problem is.

@anijain2305 (Contributor)

Hi, just a ping to get this in :)

@tqchen (Member) commented Oct 24, 2019

Superseded by #4056. Thanks @ZQPei for your contribution!

@tqchen closed this Oct 24, 2019